Introduce New Lookup-Table(LUT)-Based Matrix Multiplication Method #10181

Open · wants to merge 12 commits into master

Conversation

@QingtaoLi1 commented Nov 5, 2024

This PR introduces a new, efficient lookup-table (LUT)-based matrix multiplication method to speed up low-bit LLM inference, and adds a new tensor type named INT_N to support it. The method can provide up to a 3-4× increase in end-to-end inference throughput and a 70% reduction in energy consumption.

Unlike the existing quant-dequant methods, the LUT-based method directly supports mixed-precision GEMM (mpGEMM) without dequantization. It uses bit-wise table lookup to eliminate multiplications and reduce the additions required in matrix multiplication. In this PR, we adopt the LUT method from T-MAC. Because it uses one lookup table per weight bit, the method provides a unified solution for mpGEMM whose kernels scale linearly with the weight bit-width, instead of falling back to int8 or uint8 for all values below 8 bits.
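
To make the table-lookup idea concrete, here is a deliberately simplified, self-contained Python/NumPy sketch of bit-wise LUT-based matrix-vector multiplication. It only illustrates the principle; the actual T-MAC kernels are generated C/C++ with per-group scales, packed tables and SIMD gathers, and all names below are made up for this example.

```python
import numpy as np

# Toy illustration of bit-wise LUT-based mpGEMM (NOT the T-MAC kernels themselves).
# Weights are viewed as bit-planes; for every group of G activations we precompute
# the 2**G possible partial sums once, then every weight bit-plane only needs table
# lookups and additions -- no multiplications in the inner loop.
G = 4  # activations covered by one lookup table

def build_lut(act_group: np.ndarray) -> np.ndarray:
    """Table entry m = sum of the activations whose bit is set in mask m."""
    lut = np.zeros(2 ** G, dtype=np.float32)
    for m in range(2 ** G):
        lut[m] = sum(act_group[i] for i in range(G) if (m >> i) & 1)
    return lut

def lut_dot(w_uint: np.ndarray, bits: int, scale: float, act: np.ndarray) -> float:
    """Dot product of unsigned `bits`-bit weights with float activations via LUTs.
    A real kernel would also handle per-group scales and zero points."""
    acc = 0.0
    for k0 in range(0, act.size, G):
        lut = build_lut(act[k0:k0 + G])
        for b in range(bits):                       # one table lookup pass per weight bit
            mask = 0
            for i in range(G):                      # gather bit b of the G weights
                mask |= ((int(w_uint[k0 + i]) >> b) & 1) << i
            acc += float(lut[mask]) * (1 << b)      # weight the partial sum by bit position
    return acc * scale

# Sanity check against a plain dot product
rng = np.random.default_rng(0)
act = rng.standard_normal(16).astype(np.float32)
w = rng.integers(0, 16, size=16)                    # 4-bit unsigned weights
assert np.isclose(lut_dot(w, 4, 0.1, act), 0.1 * float(np.dot(w, act)), rtol=1e-5)
```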

We add a new data type INT_N, along with a corresponding conversion script, to support the tensor layout required by the LUT kernels. The bit rate of INT_N depends on the scale group size of the model and/or its original data type. The supported scale group sizes are {32, 64, 128, 256}, of which sizes >= 64 fully unveil the efficiency of T-MAC. For example, with 4-bit weights, F32 scales and a scale group size of 256, the bit rate is (256 × 4 + 32) / 256 = 4.125 bpw.
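
For reference, a tiny sketch of that bpw arithmetic (the helper below is ours, for illustration only, not part of the PR); it also reproduces the bpw values quoted in the model-size tables further down:

```python
def int_n_bpw(weight_bits: int, group_size: int, scale_bits: int = 32) -> float:
    """Bits per weight for INT_N: each group stores group_size quantized weights
    plus one scale (F32 here; F16 scales would use 16 bits)."""
    return (group_size * weight_bits + scale_bits) / group_size

print(int_n_bpw(4, 256))  # 4.125  (example above)
print(int_n_bpw(4, 128))  # 4.25   (4-bit model in the size table below)
print(int_n_bpw(2, 64))   # 2.5    (2-bit model in the size table below)
```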


How to Use It

Using T-MAC in llama.cpp is similar to using existing quantization methods and models, except for a few extra commands to compile the LUT kernels for the model and to convert the model into one of the currently supported data types (for now, Q4_0, TQ1_0, TQ2_0 and INT_N). GGUF models in Q4_0, TQ1_0 and TQ2_0 can run with T-MAC directly, while for INT_N we extend convert_hf_to_gguf.py to convert HuggingFace models. Compiling the LUT kernels requires the T-MAC modules as dependencies.

K-quants can also be supported by T-MAC, but this requires some engineering effort; we plan it as a TODO item.

We provide a Dockerfile to set up the T-MAC environment on Ubuntu 22.04 and a one-stop script to use T-MAC. Note that the script is only a convenience wrapper and does not introduce a new way to build the project. Here are some examples:

```bash
export EXP_ROOT=<root_path_of_experiments>
cd T-MAC

# Convert HuggingFace llama-2-7B-2bit model into INT_N and run inference
huggingface-cli download ChenMnZ/Llama-2-7b-EfficientQAT-w2g128-GPTQ --local-dir $EXP_ROOT/models/Llama-2-7b-EfficientQAT-w2g128-GPTQ
./tools/all_in_one.sh $EXP_ROOT/models/Llama-2-7b-EfficientQAT-w2g128-GPTQ/ llama-2-7b-2bit int_n --convert-model

# Convert HuggingFace llama-2-7B-4bit model into INT_N and run inference
huggingface-cli download TheBloke/Llama-2-7B-Chat-GPTQ --local-dir $EXP_ROOT/models/Llama-2-7B-Chat-GPTQ
./tools/all_in_one.sh $EXP_ROOT/models/Llama-2-7B-Chat-GPTQ/ llama-2-7b-4bit int_n --convert-model

# Run GGUF TriLM-3.9B model. Renaming is just to fit our all_in_one script.
huggingface-cli download BoscoTheDog/TriLM_BitNet_3_9B_GGUF --local-dir $EXP_ROOT/models/TriLM_BitNet_3_9B_GGUF
mv $EXP_ROOT/models/TriLM_BitNet_3_9B_GGUF/TriLM_3.9B_Unpacked_quant_TQ1_0.gguf $EXP_ROOT/models/TriLM_BitNet_3_9B_GGUF/TriLM_BitNet_3_9B_GGUF.TQ1_0.gguf
./tools/all_in_one.sh $EXP_ROOT/models/TriLM_BitNet_3_9B_GGUF/ trilm-3.9b tq1_0
```

Speed

Update: for the latest numbers, see below.

We test this PR on an Intel i7-12700 and an Apple M2-Ultra. The numbers below are in tokens/s. For details of the models, see the next section.

  • 2-bit results (tokens/s):

Apple M2:

| Threads | llama-2-7b Q2_K | llama-2-7b INT_N (2-bit, T-MAC) | TriLM-3.9b TQ1_0 | TriLM-3.9b TQ2_0 | TriLM-3.9b TQ1_0 (T-MAC) | TriLM-3.9b TQ2_0 (T-MAC) |
|---|---|---|---|---|---|---|
| 1 | 4.15 | 12.72 | 9.41 | 14.24 | 21.07 | 20.96 |
| 2 | 7.93 | 21.16 | 17.39 | 24.30 | 34.51 | 34.45 |
| 3 | 11.06 | 28.34 | 25.10 | 34.60 | 42.66 | 45.23 |
| 4 | 14.69 | 34.87 | 32.69 | 43.60 | 55.56 | 57.10 |

Intel i7-12700:

| Threads | llama-2-7b Q2_K | llama-2-7b INT_N (2-bit, T-MAC) | TriLM-3.9b TQ1_0 | TriLM-3.9b TQ2_0 | TriLM-3.9b TQ1_0 (T-MAC) | TriLM-3.9b TQ2_0 (T-MAC) |
|---|---|---|---|---|---|---|
| 1 | 4.20 | 6.26 | 6.93 | 10.62 | 11.46 | 11.19 |
| 2 | 7.87 | 11.01 | 12.75 | 18.81 | 19.68 | 19.30 |
| 3 | 10.91 | 13.66 | 16.79 | 24.22 | 25.46 | 25.36 |
| 4 | 13.15 | 14.84 | 21.95 | 29.10 | 30.49 | 31.20 |
  • 4-bit results (tokens/s):

Apple M2, llama-2-7b:

| Threads | Q4_0 | Q4_0 (T-MAC) | INT_N (4-bit, T-MAC) |
|---|---|---|---|
| 1 | 5.48 | 6.72 | 6.50 |
| 2 | 10.09 | 11.50 | 11.21 |
| 3 | 14.52 | 15.77 | 15.70 |
| 4 | 18.91 | 20.36 | 20.00 |

Intel i7-12700, llama-2-7b:

| Threads | Q4_0 | Q4_0 (T-MAC) | INT_N (4-bit, T-MAC) |
|---|---|---|---|
| 1 | 2.90 | 3.57 | 3.59 |
| 2 | 5.52 | 6.88 | 6.49 |
| 3 | 7.56 | 8.71 | 8.63 |
| 4 | 8.75 | 10.15 | 9.79 |

Model size

The INT_N models use F16 embedding and output weights, so the Q2_K and Q4_0 models here use the same configuration for a fair comparison. We use a pure Q2_K model here since its size is very close to that of INT_N (2-bit).

Note that the block sizes of the INT_N models here are 64 for 2-bit and 128 for 4-bit, and the scales are currently stored in F32.

| llama-2-7b | Q2_K | INT_N (2-bit) | Q4_0 | INT_N (4-bit) |
|---|---|---|---|---|
| Model Size | 2.47 GiB | 2.37 GiB | 3.88 GiB | 3.69 GiB |
| Bits | 2.625 bpw* | 2.5 bpw | 4.5 bpw | 4.25 bpw |

* We find that Q2_K is actually 2.625 bpw instead of the 2.5625 bpw described in #1684.
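
For reference, that figure follows from the Q2_K super-block layout (assuming the current block_q2_K struct: 64 bytes of 2-bit quants, 16 bytes of packed 4-bit sub-block scales and mins, and two F16 super-block scales per 256 weights): (512 + 128 + 32) bits / 256 weights = 2.625 bpw.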

| TriLM-3.9b | TQ1_0 | TQ2_0 |
|---|---|---|
| Model Size | 948.16 MiB | 1112.70 MiB |
| Bits | 1.6875 bpw | 2.0625 bpw |

Perplexity

See below.

Note

T-MAC has a public repo that includes llama.cpp as a third-party module. The changes inside llama.cpp can be merged directly, while the T-MAC modules will stay in that repo. The repo provides the ability to run arbitrary models with T-MAC; without it, only models covered by pre-built LUT kernels can run.

Our LUT-based method is also used in the recently open-sourced bitnet.cpp repo, which is built on llama.cpp. We can easily generate the corresponding kernels and support its models.

Future Work

  • Support NPU inference acceleration.
  • Support F16 scales and zero points.
  • Automatically detect the matmul kernel shapes for compilation.

@JohannesGaessler (Collaborator)

Did you check the KL divergence of the new datatype using llama-perplexity?

@netrunnereve (Collaborator)

> Did you check the KL divergence of the new datatype using llama-perplexity?

If we're just looking at 4-bit here, Q4_0 has 1 FP16 scale per 32 weights, while the 4-bit INT_N example uses 1 FP32 scale per 128 weights. I would imagine that perplexity would be worse than Q4_0 in this case.

In the 7B 4-bit example on the Intel there's only a 12% improvement between Q4_0 and INT_N with 4 threads. If I quickly hack ggml_vec_dot_q4_0_q8_0 to not perform the FP16-to-FP32 conversion and only use one scale per 64 weights, I already get a 5% improvement in performance, and I imagine there would be a further increase from using one scale per 128 weights. This makes me wonder how much of the improvement comes from the LUT and how much comes from simply cutting down the number of scales (with an associated loss in perplexity).

@QingtaoLi1 (Author) commented Nov 6, 2024

> If we're just looking at 4 bit here Q4_0 has 1 FP16 scale per 32 weights. In the 4 bit INT_N example it's using 1 FP32 scale per 128 weights. I would imagine that perplexity would be worse than Q4_0 in this case.

@netrunnereve We haven't tested the perplexity yet and will do so later. I would like to clarify that INT_N is just a flexible low-bit data type whose perplexity depends on the model itself; it does not change the model's precision. The 4-bit INT_N model is quantized with the GPTQ algorithm using block_size=128, so we follow that and use 1 scale per 128 weights, while the Q4_0 model is quantized with another algorithm using block_size=32, so the corresponding INT_N has 1 scale per 32 weights.

@QingtaoLi1 (Author)

> Did you check the KL divergence of the new datatype using llama-perplexity?

@JohannesGaessler We will run it and post the perplexity results here later.

@QingtaoLi1 (Author) commented Nov 11, 2024

@JohannesGaessler @netrunnereve We've checked the PPL and KL divergence of some EfficientQAT models that can be supported by our LUT method, and collected some new speed numbers.

Overview

In our tests, the LUT method yields the same PPL/KLD for Q4_0, TQ1_0 and TQ2_0 models as the original llama.cpp, while gaining some speedup. The EfficientQAT model has even lower PPL and more speedup than the Q4_0 model. The pure 2-bit EfficientQAT model has a higher PPL than the Q2_K model mixed with 3-bit weights, but is far better than the pure Q2_K model and faster (on the Apple M2 device it reaches a 3× speedup).

Note that in this reply we use EfficientQAT models in GPTQ format for our INT_N type, which are different from those in the main PR description. Models here use 1 scale and 1 zero point for each block, while models in the main PR description use only 1 scale; therefore the model sizes here are a bit larger. We could consider using F16 scales rather than F32, which would shrink the models by hundreds of MiB.
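
As a back-of-envelope estimate (ours, not a measured number): with one F32 scale and one F32 zero point per group of 64 weights, switching both to F16 saves 32 bits per 64 weights, i.e. 0.5 bpw, which for a ~7B-parameter model is roughly 400 MiB.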

Perplexity

Update at Nov. 12: added EQAT models with quantized embedding and output weights. The configs are the same as Q2_K and Q4_0, that is, Q2_K embedding & Q6_K output for 2bit, and Q4_0 embedding & Q6_K output for 4bit.

We test a few more Q2_K variants for potential comparisons of interest. "pure" means using the "--pure" option to force all weights to Q2_K, while "mixed" is the typical Q2_K model with some weights in Q3_K. "(F16)" means F16 embedding and output weights, since GPTQ/EfficientQAT models use F16. "(quantized)" means quantized embedding and output weights as described above for 2-bit and 4-bit.

All tests here are conducted with llama-perplexity. We ran the Q2_K_pure variants multiple times, but the PPL was always that large.

We'd also like to clarify that INT_N is only a data format for storing weights and computing with the LUT method. The quantization method can be EfficientQAT, BitDistiller, BitNet, etc., and the perplexity depends on it.

| Model - Llama-2-7b | Model Size | Use T-MAC? | Mean PPL(Q) | Mean ln(PPL(Q)/PPL(base)) | Mean KLD | RMS Δp | Same top p |
|---|---|---|---|---|---|---|---|
| F16 | 12.55 GiB | no | 5.7969 ± 0.03236 | | | | |
| Q2_K_pure | 2.06 GiB | no | 925.379260 ± 6.853744 | 5.073334 ± 0.007361 | 5.006663 ± 0.005624 | 60.073 ± 0.079 % | 12.874 ± 0.082 % |
| Q2_K_mixed | 2.36 GiB | no | 6.982406 ± 0.039852 | 0.186524 ± 0.001716 | 0.194111 ± 0.000911 | 13.815 ± 0.064 % | 80.757 ± 0.096 % |
| EQAT-w2g64-INT_N (quantized) | 2.40 GiB | yes | 7.367602 ± 0.042922 | 0.240223 ± 0.002012 | 0.237605 ± 0.001142 | 15.937 ± 0.072 % | 79.364 ± 0.099 % |
| Q2_K_pure (F16) | 2.47 GiB | no | 786.267639 ± 5.705384 | 4.910428 ± 0.007182 | 4.844392 ± 0.005474 | 59.783 ± 0.079 % | 13.349 ± 0.083 % |
| Q2_K_mixed (F16) | 2.71 GiB | no | 6.933891 ± 0.039533 | 0.179552 ± 0.001673 | 0.185774 ± 0.000868 | 13.537 ± 0.063 % | 81.158 ± 0.096 % |
| EQAT-w2g64-INT_N (F16) | 2.75 GiB | yes | 7.362946 ± 0.042912 | 0.239707 ± 0.002008 | 0.236106 ± 0.001139 | 15.854 ± 0.072 % | 79.468 ± 0.099 % |

| Model - Llama-2-7b | Model Size | Use T-MAC? | Mean PPL(Q) | Mean ln(PPL(Q)/PPL(base)) | Mean KLD | RMS Δp | Same top p |
|---|---|---|---|---|---|---|---|
| Q4_0 | 3.56 GiB | no | 5.962298 ± 0.033478 | 0.028586 ± 0.000668 | 0.028983 ± 0.000299 | 4.757 ± 0.032 % | 92.323 ± 0.065 % |
| Q4_0 | 3.56 GiB | yes | 5.962719 ± 0.033482 | 0.028657 ± 0.000668 | 0.028951 ± 0.000299 | 4.752 ± 0.032 % | 92.359 ± 0.065 % |
| EQAT-w4g128-INT_N (quantized) | 3.56 GiB | yes | 5.884062 ± 0.032854 | 0.015378 ± 0.000692 | 0.031924 ± 0.000259 | 4.972 ± 0.036 % | 92.167 ± 0.066 % |
| Q4_0 (F16) | 3.88 GiB | no | 5.961833 ± 0.033476 | 0.028508 ± 0.000664 | 0.028577 ± 0.000298 | 4.716 ± 0.032 % | 92.390 ± 0.065 % |
| Q4_0 (F16) | 3.88 GiB | yes | 5.960972 ± 0.033468 | 0.028364 ± 0.000663 | 0.028516 ± 0.000297 | 4.708 ± 0.032 % | 92.392 ± 0.065 % |
| EQAT-w4g128-INT_N (F16) | 3.88 GiB | yes | 5.881885 ± 0.032837 | 0.015008 ± 0.000688 | 0.031411 ± 0.000255 | 4.937 ± 0.036 % | 92.316 ± 0.065 % |

| Model - TriLM-3.9b | Model Size | Use T-MAC? | Mean PPL(Q) | Mean ln(PPL(Q)/PPL(base)) | Mean KLD | RMS Δp | Same top p |
|---|---|---|---|---|---|---|---|
| F16 | 7.44 GiB | no | 11.1511 ± 0.07852 | | | | |
| TQ1_0 | 0.93 GiB | no | 11.162545 ± 0.078589 | 0.002493 ± 0.000221 | 0.001290 ± 0.000004 | 0.819 ± 0.004 % | 97.793 ± 0.039 % |
| TQ2_0 | 1.09 GiB | no | 11.162545 ± 0.078589 | 0.002493 ± 0.000221 | 0.001290 ± 0.000004 | 0.819 ± 0.004 % | 97.793 ± 0.039 % |
| TQ1_0 | 0.93 GiB | yes | 11.161108 ± 0.078579 | 0.002364 ± 0.000214 | 0.001111 ± 0.000003 | 0.751 ± 0.003 % | 97.967 ± 0.037 % |
| TQ2_0 | 1.09 GiB | yes | 11.161108 ± 0.078579 | 0.002364 ± 0.000214 | 0.001111 ± 0.000003 | 0.751 ± 0.003 % | 97.967 ± 0.037 % |

Speed

All tests here are tg128.

We apologize that the previous numbers for Q4_0 with T-MAC were wrong. There is a further increase when going from block_size=32 (Q4_0) to block_size=128 (EQAT-w4g128). Thanks @netrunnereve for pointing this out.

Apple M2-Ultra, llama-2-7b, 2-bit:

| Threads | Q2_K_pure | Q2_K_mixed | INT_N (quantized, T-MAC) | Q2_K_pure (F16) | Q2_K_mixed (F16) | INT_N (F16, T-MAC) |
|---|---|---|---|---|---|---|
| 1 | 4.59 ± 0.00 | 4.19 ± 0.12 | 13.15 ± 0.13 | 4.71 ± 0.04 | 4.23 ± 0.10 | 12.86 ± 0.04 |
| 2 | 8.72 ± 0.13 | 7.78 ± 0.12 | 20.32 ± 0.28 | 8.88 ± 0.09 | 7.98 ± 0.08 | 20.47 ± 0.10 |
| 3 | 12.61 ± 0.11 | 11.14 ± 0.09 | 27.63 ± 0.06 | 12.67 ± 0.16 | 11.48 ± 0.17 | 27.83 ± 0.16 |
| 4 | 16.38 ± 0.07 | 14.93 ± 0.12 | 33.82 ± 0.02 | 16.18 ± 0.07 | 14.80 ± 0.18 | 34.48 ± 0.20 |

Apple M2-Ultra, llama-2-7b, 4-bit:

| Threads | Q4_0 | Q4_0 (T-MAC) | INT_N (quantized, T-MAC) | Q4_0 (F16) | Q4_0 (F16, T-MAC) | INT_N (F16, T-MAC) |
|---|---|---|---|---|---|---|
| 1 | 5.35 ± 0.14 | 6.60 ± 0.03 | 7.02 ± 0.04 | 5.69 ± 0.03 | 6.38 ± 0.08 | 7.01 ± 0.02 |
| 2 | 9.78 ± 0.03 | 11.46 ± 0.03 | 12.54 ± 0.04 | 10.16 ± 0.09 | 10.97 ± 0.05 | 12.30 ± 0.01 |
| 3 | 14.37 ± 0.03 | 15.18 ± 0.04 | 17.44 ± 0.05 | 14.60 ± 0.15 | 15.86 ± 0.04 | 17.26 ± 0.05 |
| 4 | 18.49 ± 0.03 | 18.67 ± 1.23 | 21.75 ± 0.16 | 18.70 ± 0.05 | 18.99 ± 0.08 | 21.59 ± 0.13 |

Intel i7-12700, llama-2-7b, 2-bit:

| Threads | Q2_K_pure | Q2_K_mixed | INT_N (quantized, T-MAC) | Q2_K_pure (F16) | Q2_K_mixed (F16) | INT_N (2-bit, T-MAC) |
|---|---|---|---|---|---|---|
| 1 | 4.33 ± 0.10 | 3.69 ± 0.06 | 6.42 ± 0.34 | 4.31 ± 0.03 | 3.70 ± 0.02 | 6.10 ± 0.18 |
| 2 | 8.44 ± 0.20 | 7.10 ± 0.14 | 11.09 ± 0.61 | 8.11 ± 0.13 | 6.69 ± 0.10 | 11.03 ± 0.58 |
| 3 | 11.39 ± 0.29 | 9.63 ± 0.28 | 13.92 ± 1.14 | 10.94 ± 0.19 | 9.38 ± 0.26 | 13.09 ± 0.78 |
| 4 | 13.70 ± 0.38 | 11.59 ± 0.29 | 14.83 ± 1.60 | 13.00 ± 0.27 | 11.32 ± 0.28 | 13.97 ± 1.05 |

Intel i7-12700, llama-2-7b, 4-bit:

| Threads | Q4_0 | Q4_0 (T-MAC) | INT_N (quantized, T-MAC) | Q4_0 (F16) | Q4_0 (F16, T-MAC) | INT_N (F16, T-MAC) |
|---|---|---|---|---|---|---|
| 1 | 2.92 ± 0.04 | 3.13 ± 0.07 | 3.18 ± 0.08 | 2.86 ± 0.04 | 3.10 ± 0.01 | 3.49 ± 0.05 |
| 2 | 5.59 ± 0.21 | 5.81 ± 0.19 | 5.96 ± 0.43 | 5.31 ± 0.07 | 5.64 ± 0.02 | 6.54 ± 0.26 |
| 3 | 7.60 ± 0.54 | 7.80 ± 0.30 | 7.80 ± 0.53 | 7.11 ± 0.30 | 7.38 ± 0.18 | 8.70 ± 0.35 |
| 4 | 8.97 ± 0.59 | 9.16 ± 0.47 | 9.28 ± 0.69 | 8.48 ± 0.19 | 8.39 ± 0.07 | 9.79 ± 0.47 |

Apple M2-Ultra, TriLM-3.9b, 2-bit:

| Threads | TQ1_0 | TQ2_0 | TQ1_0 (T-MAC) | TQ2_0 (T-MAC) |
|---|---|---|---|---|
| 1 | 9.46 ± 0.04 | 14.52 ± 0.13 | 22.09 ± 0.22 | 21.46 ± 0.27 |
| 2 | 17.58 ± 0.15 | 24.60 ± 0.28 | 32.81 ± 0.09 | 32.99 ± 0.12 |
| 3 | 24.11 ± 0.01 | 35.06 ± 0.27 | 44.62 ± 0.15 | 44.80 ± 0.23 |
| 4 | 30.57 ± 0.88 | 44.44 ± 0.19 | 54.77 ± 0.69 | 55.14 ± 0.06 |

Intel i7-12700, TriLM-3.9b, 2-bit:

| Threads | TQ1_0 | TQ2_0 | TQ1_0 (T-MAC) | TQ2_0 (T-MAC) |
|---|---|---|---|---|
| 1 | 6.78 ± 0.18 | 10.11 ± 0.10 | 10.36 ± 0.21 | 11.08 ± 0.08 |
| 2 | 12.76 ± 0.67 | 19.02 ± 0.18 | 19.30 ± 0.22 | 19.40 ± 0.32 |
| 3 | 17.74 ± 1.47 | 25.50 ± 0.24 | 26.33 ± 0.07 | 26.31 ± 0.23 |
| 4 | 21.34 ± 2.06 | 29.33 ± 0.29 | 30.09 ± 0.75 | 29.86 ± 0.69 |

@JohannesGaessler (Collaborator)

Thanks for the numbers. I was specifically asking because I primarily work on GPU code and there the constraints are very different. In particular, on GPUs the main memory is comparatively small vs. the amount of available compute. If I read the numbers correctly, EQAT does not compress the data as efficiently as the quantization methods on master (and would thus not be a good fit for GPUs).

@QingtaoLi1 (Author) commented Nov 12, 2024

> EQAT does not compress the data as efficiently as the quantization methods on master (and would thus not be a good fit for GPUs).

@JohannesGaessler This PR is mainly about adding support for a LUT-based matrix multiplication kernel library, which aims to speed up CPU inference.

  1. Our LUT-based kernel library now supports the Q4_0 and TQ data types, and the numbers I posted yesterday show that accuracy is identical for the same quantization type (Q4_0, TQ1_0, TQ2_0) whether it runs through T-MAC or the original llama.cpp. The EQAT algorithm is orthogonal to the LUT method; we mention it because it can be supported easily and has the potential for better PPL and/or speed.
  2. Our LUT-based kernel library supports CPU for now, and we are working on NPU support.
  3. The model sizes look larger because we are using F16 embedding and output weights in the table. They can easily be quantized to master data types using llama-quantize. For example, if we use the same embedding and output weight types, the EQAT 2-bit model shrinks to almost the same size as the default Q2_K. I have updated the numbers above, adding the models with quantized embeddings; see the sizes below.
Model size:

| Model | Size |
|---|---|
| Q2_K_mixed (default embed and output) | 2.36 GiB |
| Q2_K_mixed (F16) | 2.71 GiB |
| EQAT-w2g64-INT_N (default embed and output) | 2.40 GiB |
| EQAT-w2g64-INT_N (F16) | 2.75 GiB |

@netrunnereve (Collaborator)

Okay, I see it more clearly now. Assuming we're using regular Q4_0 on the i7 with Llama 2 7B and 4 threads, we get 8.97 t/s with T-MAC off and 9.16 t/s with T-MAC on (a 2% improvement).

Keep in mind that while the INT_N perplexity looks better, it's using a specially prepared QAT model. So basically, with QAT we can use one scale per 128 weights and get 9.28 t/s (an additional 1% improvement). However, we have our K-quants, and something like Q4_K_M has a perplexity of 5.877 on Llama 2 7B, which slightly beats the 5.884 of INT_N. And that's with no QAT needed and basically the same or better performance compared to Q4_0.

On the other hand, it looks like there are some genuine performance improvements on the 2-bit side, though perplexity is higher than Q2_K. For 4-bit, whether we have T-MAC or EQAT, I honestly don't think this method is worth it.

@QingtaoLi1 (Author) commented Nov 13, 2024

@netrunnereve Thanks for your detailed comment!

> Okay I see it more clearly now, assuming we're using regular Q4_0 on the i7 with Llama 2 7B and 4 threads we get 8.97 t/s with T-MAC off and 9.16 t/s with T-MAC on (2% improvement).

To explain the speedup: we are proposing a brand-new way to compute low-bit matmul, implemented with a different set of instructions. Memory bandwidth and the CPI ratio between the LUT instructions and the MUL instructions on different CPUs can both affect the speedup of LUT over MUL. This is different from many of the great optimizations you've made inside the MUL-based method.

Our LUT method, like any other, cannot beat the alternatives once the workload becomes memory bound. As you can see, LUT beats the existing master data types (i.e. quantization types) by a larger margin in the single-thread cases. You are also welcome to check the T-MAC repo for more numbers; on edge devices, LUT gains an even higher speedup.
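
As a rough illustration of that memory-bound ceiling (the bandwidth figure below is a made-up placeholder, not a measurement): during single-batch decoding every token streams essentially the whole weight tensor from memory, so no matmul kernel can exceed bandwidth divided by model size.

```python
def decode_ceiling_tps(model_size_gib: float, mem_bandwidth_gb_s: float) -> float:
    """Upper bound on decode tokens/s when generation is purely memory-bandwidth bound."""
    bytes_per_token = model_size_gib * 1024 ** 3   # roughly the whole model is read per token
    return mem_bandwidth_gb_s * 1e9 / bytes_per_token

# e.g. a 3.56 GiB 4-bit llama-2-7b on a desktop CPU with ~60 GB/s of usable bandwidth
print(round(decode_ceiling_tps(3.56, 60.0), 1))    # ~15.7 tokens/s ceiling
```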

> Keep in mind that while the INT_N perplexity looks better it's using a specially prepared QAT model. So basically with QAT we can use one scale per 128 weights and get 9.28 t/s (additional 1% improvement). [...] And that's with no QAT needed and basically the same or better performance compared to Q4_0.

  • Better on existing data types and quantized models. INT_N gets the same perplexity, a much faster speed for the TQ types (and probably the Q2 types), and a relatively small but real speedup for the Q4 type. I understand your observation that in one specific case the numbers may not look very exciting, but I think the existing results already show that the LUT method outperforms the MUL method in a broad range of scenarios, even for Q4_0 with its small block_size=32, where T-MAC cannot show its full strength.
  • We are not proposing the EQAT quantization algorithm, but why not take advantage of a zoo of models with potentially better performance? They are just standing there, easy to download, convert to the proper format, and use. If I understand correctly, the current master data types do not cover those methods, so our work can help extend the reach of llama.cpp.

> However we have our K-quants and something like Q4_K_M has perplexity of 5.877 on Llama 2 7B which slightly beats the 5.884 of INT_N.

To clarify the comparison: I have to note that the comparison between Q4_K_M = 5.877 and INT_N = 5.884 is not fair. Q4_K_M uses Q6_K for some weights, so its overall bpw grows to over 5. A fair competitor would be Q4_K_S.

I also don't know yet why my result differs from that document. In my test, the Llama-2-7b F16 model gives a PPL of 5.7969 and Q2_K_M gives 6.982406, but in https://github.com/ggerganov/llama.cpp/blob/master/examples/perplexity/README.md Q2_K_M is 5.794552, which is even lower than Q4_K and Q6_K. You may want to double-check the numbers in that document.

> On the other hand it looks like there are some genuine performance improvements on the 2 bit side, though perplexity is higher than Q2_K. For 4 bit whether we have T-MAC or EQAT I honestly don't think this method is worth it.

Thanks for your affirmation of the 2-bit side. I guess you may have misunderstood the proposed INT_N type? (Correct me if I'm wrong!)

  • INT_N is our proposed data type. It is needed to support the T-MAC calculation method; models quantized with the EQAT algorithm are only an example to show T-MAC's ability.
  • The master data types are bound to certain quantization methods, while INT_N is not. For example, Q2_K is both a data type with bpw=2.625 and a quantization method that rounds every 16 weights to 2 bits and then rounds the per-16-weight scales again to 4 bits. INT_N, by contrast, can cover a wide range of quantization methods, as long as they can be expressed with rounding blocks of size 32/64/128/256 at 1/2/3/4 or even more bits. EQAT is only one example.
  • There is almost no extra effort to support 1-4 bits in T-MAC. It is a unified method for all low bit-widths, so supporting 2-bit means supporting 4-bit at the same time, and people can freely choose whether to use T-MAC in their particular scenario.

@netrunnereve (Collaborator)

> Better in existing data type and quantized models. INT_N get the same perplexity, much faster speed in TQ types (and probably Q2 types) and relatively-small-but-not-none speedup in Q4 type.

I won't comment on the TQ types as I'm unfortunately not familiar enough with the implementation and quant methods. For 2-bit there's a good performance increase compared to Q2_K, but perplexity is higher (and that's with QAT already). For Q4_0, the current AVX2 implementation you're likely using on your i7 is... suboptimal to say the least, and I think fully optimized it'll perform very similarly to your Q4_0 T-MAC.

> We're not proposing EQAT quantization algorithm, but why minding a zoo of models with potential better performance? They are just standing there, easily being downloaded and converted to proper format and used.

> INT_N is our proposed data type. It is needed to support T-MAC calculation method. Models quantized by EQAT algorithm is only our example to show the ability of T-MAC.

The issue here is that it becomes less of a fair comparison, considering that our quants don't require EQAT. Only the creator of EQAT is uploading models, and the selection is limited compared to the thousands of interesting finetunes on Hugging Face which work great with our K- and I-quants. It's a good idea to support QAT, but it makes the perplexity comparisons somewhat unfair.

> And I don't know yet why my result is different from that document. In my test, Llama-2-7b F16 model gives 5.7969 PPL and Q2_K_M gives 6.982406, but in https://github.com/ggerganov/llama.cpp/blob/master/examples/perplexity/README.md Q2_K_M is 5.794552 which is even lower than Q4_K and Q6_K. I think you may double-check the numbers in that document.

Yeah, that Q2_K_M number is definitely wrong, as it's lower than the Q6_K and Q8_0 results in the same table. When using a Q2_K 7B model, the loss of quality is extremely obvious just from looking at the generated text. It's probably best to publish and trust your own benchmarks in this case.

> I have to note that the comparison between Q4_K_M=5.877 and INT_N=5.884 is not fair. Q4_K_M uses Q6_K in some weights, so the overall bpw grows to over 5. A fair competitor should be Q4_K_S.

I'm not trying to be nitpicky here, and I apologize if I sound that way; I'm just a bit skeptical about the claims. I think it would be a good idea to run benchmarks and perplexity tests against our SOTA quants of similar bpw, rather than comparing with Q4_0 and Q2_K, which really aren't used nowadays. For 4-bit that's probably Q4_K_S (yeah, Q4_K_M is a bit too large), IQ4_NL, and IQ4_XS. For 2-bit, Q2_K has been superseded by the I-quants.

@QingtaoLi1 (Author) commented Nov 19, 2024

@netrunnereve I've tested Q4_K_S, IQ4_NL and IQ4_XS. Their PPLs are all higher than 5.88, while Q4_K_S has 8 Q5_K tensors and the IQ4 variants have 4 Q5_K tensors. IQ4_XS performs the best of the three, with the smallest model size and the best perplexity.

For speed, Q4_K_S > IQ4_XS > IQ4_NL. In the 4-thread scenario you are concerned with, compared to INT_N, Q4_K_S is 6% faster on my i7-12700 and 2% slower on the M2-Ultra, while IQ4_XS is about the same on the i7-12700 and 22% slower on the M2-Ultra.

So according to the perplexity and speed numbers, I think Q4_K_S and IQ4_XS are two Pareto-front points with different trade-offs, and IQ4_NL is dominated by IQ4_XS. The INT_N model I tested is not worse than IQ4_XS in these aspects and is faster on the M2, and it is roughly on par with Q4_K_S in these aspects while better in perplexity. So it comes close to covering both master types.

| Model - Llama-2-7b | Model Size | Mean PPL(Q) | Mean ln(PPL(Q)/PPL(base)) | Mean KLD | RMS Δp | Same top p |
|---|---|---|---|---|---|---|
| Q4_0 | 3.56 GiB | 5.962298 ± 0.033478 | 0.028586 ± 0.000668 | 0.028983 ± 0.000299 | 4.757 ± 0.032 % | 92.323 ± 0.065 % |
| EQAT-w4g128-INT_N (quantized) | 3.56 GiB | 5.884062 ± 0.032854 | 0.015378 ± 0.000692 | 0.031924 ± 0.000259 | 4.972 ± 0.036 % | 92.167 ± 0.066 % |
| Q4_K_S | 3.58 GiB | 5.969718 ± 0.033393 | 0.029830 ± 0.000644 | 0.028650 ± 0.000189 | 4.929 ± 0.034 % | 92.313 ± 0.065 % |
| IQ4_NL | 3.58 GiB | 5.919791 ± 0.033146 | 0.021432 ± 0.000594 | 0.022300 ± 0.000286 | 4.207 ± 0.031 % | 93.308 ± 0.061 % |
| IQ4_XS | 3.40 GiB | 5.890478 ± 0.032849 | 0.016468 ± 0.000535 | 0.019815 ± 0.000158 | 4.021 ± 0.028 % | 93.487 ± 0.060 % |

I have also run some tests on a Raspberry Pi. The peak speed of INT_N is about 20% faster than Q4_0 (3.29 tokens/s vs. 2.73 tokens/s) and about 50% faster than Q2_K_S (4.98 tokens/s vs. 3.26 tokens/s). Edge devices are more bound by computation than by memory, so the gap becomes obvious there.

@scarlett2018

@netrunnereve any comments or suggestions on these test results would be highly appreciated!


@netrunnereve (Collaborator)

Sorry, I missed your comment! In this case I can now see that EQAT-w4g128-INT_N beats the K- and I-quants of similar size in terms of perplexity and speed, on the condition that the model is finetuned with EQAT. Honestly, I'm not sure how much interest this project has in supporting a specific type of QAT'd model.

As this is a new quant type with an associated maintenance requirement, you'll probably need the core owners like GG or slaren to look into the viability of accepting this PR.

@scarlett2018
Thanks @netrunnereve for the feedback on the new results, and for suggesting that we contact the core owners.

@ggerganov @slaren - could you review and advise on the viability of accepting this PR? If maintenance is a concern, we can help maintain this feature.


@slaren (Collaborator) commented Dec 6, 2024

Generally my opinion is that if there are some cases where these types perform better than any others, it would be good to merge. From what I understand from the discussion here, that seems to be the case.

It looks like the code at the moment depends on an external library, which could be a problem. Is the intention to add the library code here?

@QingtaoLi1 (Author)

@slaren Thanks for your positive reply!

There are two parts to the T-MAC code:

  1. The code in this pull request. This introduces necessary interfaces, helper functions, and a new data type for the LUT-based mul_mat method.
  2. The code in the T-MAC repo. This handles the search and generation of the LUT-based mul_mat kernels and wrappers.

The second part will be an external library. The kernels and wrappers provide an alternative implementation of mul_mat for certain low-bit types (see ggml-cpu.c for the replacement). Once the kernels are generated, the T-MAC module will no longer be involved in the subsequent llama.cpp build and runtime process.

However, the second part of the code is still necessary for general use. The kernels may vary depending on factors like weight bits, mul_mat shapes, and devices. Even when kernels are portable across devices with the same ISA, customizing the kernels for the new device will lead to better performance.

We see two possible ways to integrate T-MAC:

  1. As a third-party submodule that references the T-MAC repo (similar to OpenBLAS). This would require less integration effort and make it easier to track new updates.
  2. As a new backend, placing the code under ggml/src/ggml-tmac or another appropriate directory.

Which approach would you prefer?

@slaren (Collaborator) commented Dec 10, 2024

As a rule of thumb, we should not add dependencies on 3rd party libraries to the ggml code. Backends are an exception, since adding 3rd party libraries is usually unavoidable in that case. However, adding a new backend for the T-MAC types would not be a good solution either; we should keep all the CPU code in the CPU backend. This is especially important due to recent changes that add the ability to load backends dynamically: a binary package of llama.cpp may bundle multiple versions of the CPU backend compiled with different instruction sets, and we do not want to also have to bundle multiple versions of a T-MAC backend.

My recommendation would be the following:

  • Remove the 3rd party library dependency and add all the code here
  • If the implementation needs to do some kind of repacking of the tensor data optimized for the current CPU, do so using the interfaces added in Refactor/online repacking #10446 for this purpose

@QingtaoLi1 (Author)

@slaren Thanks for your constructive reply!

Since "the 3rdparty code" is implemented in Python, we have to find a proper way to place it and to add in the existing workflow if we add all the codes here. We see it possible to treat them as gguf-py and convert_xxx.py by adding a new Python folder and making it an independent Python command to run before building llama.cpp. Does it meet your rules? You can check our Python entry for more details if you'd like.

We also have two other questions:

  1. We currently have TVM as a 3rd-party dependency. Is it a must to remove this dependency if we move all the code here?
  2. Our LUT-based method dynamically generates the mul_mat C++ code. Do we need to take into consideration that some users may use a llama.cpp binary package instead of building from source?

@slaren (Collaborator) commented Dec 14, 2024

Thanks for the explanation. I see now that I severely underestimated how complicated it would be to bring this code to ggml.

Adding Apache TVM as a dependency of ggml is not a possibility. Using Python scripts to generate the kernels may be OK depending on the circumstances, but if this means that every model needs a different set of kernels, that is likely going too far. We absolutely need to be able to distribute binary packages of llama.cpp, since it needs to run on edge/end-user devices.

I can give you a few pointers, but realistically, I don't see how you could bring all this system into ggml in a way that integrates with the existing code, and would not become an unreasonable maintenance burden, without effectively rewriting large parts of it. At this moment I cannot commit the time that would be necessary to even figure how to fit all of this into the existing ggml code.

@Zant12 commented Dec 17, 2024

> Support NPU inference acceleration.

@QingtaoLi1 The NPU support sounds really promising. I have tested the Snapdragon 8 Elite NPU (supports float16, INT8, INT4) locally with their QNN Genie bundle framework, and found that prompt-processing performance is more than a magnitude greater than the decode speed, around seven hundred tokens per second for 7B at 4-bit (40×). The NPU is faster than GPU acceleration, which in my tests only has a ~3:1 prompt-processing-to-decoding ratio.

Were you thinking of Apple Silicon NPU hardware?

@QingtaoLi1 (Author) commented Dec 17, 2024

@slaren Thanks, I get the point. We have discussed this conflict in light of the llama.cpp community's rules, and there is a possible plan to reach an agreement:

  1. Remove the tvm dependency.
  2. Prepare a set of static kernels for x86 and ARM, and support different shapes and models by composing these static kernels.

As you pointed out, this plan does require rewriting a large part of the T-MAC code.

@QingtaoLi1 (Author)

@Zant12 We are working on the Qualcomm NPU, and currently have no plan for the Apple Silicon NPU.
